from typing import Any, Callable, Optional

from core.agents.response import AgentResponse
from core.config import get_config
from core.db.models import ProjectState
from core.llm.base import BaseLLMClient, LLMError
from core.log import get_logger
from core.proc.process_manager import ProcessManager
from core.state.state_manager import StateManager
from core.ui.base import AgentSource, UIBase, UserInput, pythagora_source

log = get_logger(__name__)


class BaseAgent:
    """
    Base class for agents.
    """

    agent_type: str
    display_name: str

    def __init__(
        self,
        state_manager: StateManager,
        ui: UIBase,
        *,
        step: Optional[Any] = None,
        prev_response: Optional["AgentResponse"] = None,
        process_manager: Optional["ProcessManager"] = None,
        data: Optional[Any] = None,
        args: Optional[Any] = None,
    ):
        """
        Create a new agent.
        """
        self.ui_source = AgentSource(self.display_name, self.agent_type)
        self.ui = ui
        self.state_manager = state_manager
        self.process_manager = process_manager
        self.prev_response = prev_response
        self.step = step
        self.data = data
        self.args = args

    @property
    def current_state(self) -> ProjectState:
        """Current state of the project (read-only)."""
        return self.state_manager.current_state

    @property
    def next_state(self) -> ProjectState:
        """Next state of the project (write-only)."""
        return self.state_manager.next_state

    async def send_message(self, message: str, extra_info: Optional[dict] = None):
        """
        Send a message to the user.

        Convenience method, uses `UIBase.send_message()` to send the message,
        setting the correct source and project state ID.

        :param message: Message to send.
        :param extra_info: Extra information to indicate special functionality in extension
        """
        await self.ui.send_message(
            message + "\n", source=self.ui_source, project_state_id=str(self.current_state.id), extra_info=extra_info
        )

    async def ask_question(
        self,
        question: str,
        *,
        buttons: Optional[dict[str, str]] = None,
        default: Optional[str] = None,
        buttons_only: bool = False,
        allow_empty: bool = False,
        full_screen: Optional[bool] = False,
        hint: Optional[str] = None,
        verbose: bool = True,
        initial_text: Optional[str] = None,
        extra_info: Optional[dict] = None,
        placeholder: Optional[str] = None,
    ) -> UserInput:
        """
        Ask a question to the user and return the response.

        Convenience method, uses `UIBase.ask_question()` to
        ask the question, setting the correct source and project state ID, and
        logging the question/response.

        :param question: Question to ask.
        :param buttons: Buttons to display with the question.
        :param default: Default button to select.
        :param buttons_only: Only display buttons, no text input.
        :param allow_empty: Allow empty input.
        :param full_screen: Show question full screen in extension.
        :param hint: Text to display in a popup as a hint to the question.
        :param verbose: Whether to log the question and response.
        :param initial_text: Initial text input.
        :param extra_info: Extra information to indicate special functionality in extension.
        :param placeholder: Placeholder text for the input field.
        :return: User response.
        """
        response = await self.ui.ask_question(
            question,
            buttons=buttons,
            default=default,
            buttons_only=buttons_only,
            allow_empty=allow_empty,
            full_screen=full_screen,
            hint=hint,
            verbose=verbose,
            initial_text=initial_text,
            source=self.ui_source,
            project_state_id=str(self.current_state.id) if self.current_state.prev_state_id is not None else None,
            extra_info=extra_info,
            placeholder=placeholder,
        )
        # Store the access token in the state manager
        if hasattr(response, "access_token") and response.access_token:
            self.state_manager.update_access_token(response.access_token)
        await self.state_manager.log_user_input(question, response)
        return response

    async def stream_handler(self, content: str):
        """
        Handle streamed response from the LLM.

        Serves as a callback to `AgentBase.llm()` so it can stream the responses to the UI.

        :param content: Response content.
        """
        route = getattr(self, "_current_route", None)
        await self.ui.send_stream_chunk(
            content, source=self.ui_source, project_state_id=str(self.current_state.id), route=route
        )

        if content is None:
            await self.ui.send_message("", source=self.ui_source, project_state_id=str(self.current_state.id))

    async def error_handler(self, error: LLMError, message: Optional[str] = None) -> bool:
        """
        Handle error responses from the LLM.

        :param error: The exception that was thrown the the LLM client.
        :param message: Optional message to show.
        :return: Whether the request should be retried.
        """

        if error == LLMError.KEY_EXPIRED:
            await self.ui.send_key_expired(message)
            answer = await self.ask_question(
                "Would you like to retry the last step?",
                buttons={"yes": "Yes", "no": "No"},
                buttons_only=True,
            )
            if answer.button == "yes":
                return True
        elif error == LLMError.GENERIC_API_ERROR:
            await self.stream_handler(message)
            answer = await self.ui.ask_question(
                "Would you like to retry the failed request?",
                buttons={"yes": "Yes", "no": "No"},
                buttons_only=True,
                source=pythagora_source,
            )
            if answer.button == "yes":
                return True
        elif error == LLMError.RATE_LIMITED:
            await self.stream_handler(message)

        return False

    def get_llm(self, name=None, stream_output=False, route=None) -> Callable:
        """
        Get a new instance of the agent-specific LLM client.

        The client initializes the UI stream handler and stores the
        request/response to the current state's log. The agent name
        can be overridden in case the agent needs to use a different
        model configuration.

        :param name: Name of the agent for configuration (default: class name).
        :param stream_output: Whether to enable streaming output.
        :param route: Route information for message routing.
        :return: LLM client for the agent.
        """

        if name is None:
            name = self.__class__.__name__

        config = get_config()

        llm_config = config.llm_for_agent(name)
        client_class = BaseLLMClient.for_provider(llm_config.provider)
        stream_handler = self.stream_handler if stream_output else None
        llm_client = client_class(
            llm_config,
            stream_handler=stream_handler,
            error_handler=self.error_handler,
            ui=self.ui,
            state_manager=self.state_manager,
        )

        async def client(convo, **kwargs) -> Any:
            """
            Agent-specific LLM client.

            For details on optional arguments to pass to the LLM client,
            see `pythagora.llm.openai_client.OpenAIClient()`.
            """
            # Set the route for this LLM request
            self._current_route = route
            try:
                response, request_log = await llm_client(convo, **kwargs)
                await self.state_manager.log_llm_request(request_log, agent=self)
                return response
            finally:
                # Clear the route after the request
                self._current_route = None

        return client

    async def run() -> AgentResponse:
        """
        Run the agent.

        :return: Response from the agent.
        """
        raise NotImplementedError()
